# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F


class ChnlAttention(nn.Module):
    """Squeeze-and-Excitation style channel attention.

    Global-average-pools the input to (1, 1), runs it through a two-layer
    1x1-conv bottleneck (reduce by ``ratio``, then restore), and returns a
    per-channel gate in (0, 1) with shape (B, C, 1, 1) for broadcasting.

    Args:
        in_channels: number of channels C of the input feature map.
        ratio: bottleneck reduction factor for the hidden layer.
    """

    def __init__(self, in_channels, ratio=16):
        super(ChnlAttention, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        # BUG FIX: guard with max(1, ...) — the original `in_channels // ratio`
        # yields 0 hidden channels (an invalid Conv2d) when in_channels < ratio.
        hidden = max(1, in_channels // ratio)
        self.compress = nn.Conv2d(in_channels, hidden, 1, 1, 0)
        self.excitation = nn.Conv2d(hidden, in_channels, 1, 1, 0)

    def forward(self, x):
        out = self.squeeze(x)         # (B, C, H, W) -> (B, C, 1, 1)
        out = F.relu(self.compress(out))
        out = self.excitation(out)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(out)


class CHNL_Block(nn.Module):
    """Residual block with channel attention on the main path.

    Main path: two 3x3 conv + batch-norm layers (ReLU between them); the
    result is rescaled per-channel by ChnlAttention, then added to an
    identity (or 1x1-projected) shortcut of the input.

    Args:
        in_channel: input channel count.
        out_channel: base output channel count (scaled by ``expansion``).
        stride: stride of the first conv; also applied to the shortcut
            projection so spatial sizes match.
    """

    expansion = 1  # output-channel multiplier, ResNet-style

    def __init__(self, in_channel, out_channel, stride=1):
        super(CHNL_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            # FIX: bias=False — a conv bias directly before BatchNorm2d is
            # redundant (BN subtracts the per-channel mean), and the original
            # bias=True was inconsistent with the first conv.
            nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel * self.expansion)
        )
        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 conv projection matches dimensions.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channel != out_channel * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel * self.expansion)
            )
        self.chnl_atten = ChnlAttention(out_channel * self.expansion)

    def forward(self, x):
        out = self.layer(x)
        # Per-channel gating; the (B, C, 1, 1) gate broadcasts over H and W.
        out = out * self.chnl_atten(out)
        # NOTE(review): no ReLU after the residual add — standard (SE-)ResNet
        # applies one here; confirm the omission is intentional.
        return out + self.shortcut(x)


if __name__ == '__main__':
    # Smoke test: push one (b, c, h, w) tensor through a single block.
    x = torch.randn(1, 16, 128, 64)
    # BUG FIX: out_channel is a required positional parameter — the original
    # call CHNL_Block(in_channel=16) raised a TypeError at runtime.
    ca_model = CHNL_Block(in_channel=16, out_channel=16)
    y = ca_model(x)
    print(y.shape)  # expected: torch.Size([1, 16, 128, 64])